/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <algorithm>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "DexClass.h"
#include "DexInstruction.h"
#include "DexLoader.h"
#include "GlobalConfig.h"
#include "IRAssembler.h"
#include "IRCode.h"
#include "LiveRange.h"
#include "RClass.h"
#include "RedexContext.h"
#include "RedexResources.h"
#include "RedexTest.h"
#include "Show.h"
#include "ShowCFG.h"

namespace {
// SHOW on the code/cfg by default does not print out payloads; use the special
// printer;
void dump_code_verbose(IRCode* code) {
  auto& cfg = code->cfg();
  std::cout << show_res_payloads(cfg);
}

DexClass* get_r_class(const DexClasses& classes, const char* name) {
  DexClass* r_class = nullptr;
  for (const auto& cls : classes) {
    if (strcmp(cls->c_str(), name) == 0) {
      r_class = cls;
    }
  }
  always_assert_log(r_class != nullptr, "Did not find class %s!", name);
  auto* clinit = r_class->get_clinit();
  always_assert_log(clinit != nullptr, "%s should have a <clinit>", name);
  auto* code = clinit->get_code();
  always_assert_log(code != nullptr, "%s should have <clinit> code", name);
  return r_class;
}

void dump_field_values(const resources::FieldArrayValues& field_values) {
  for (auto&& [f, vec] : field_values) {
    std::cerr << SHOW(f) << " -> { ";
    bool first = true;
    for (const auto i : vec) {
      if (!first) {
        std::cerr << ", ";
      }
      std::cerr << "0x" << std::hex << i << std::dec;
      first = false;
    }
    std::cerr << " }" << "\n";
  }
}

// For use in analyzing an autogenerated R class <clinit>, find the const
// instruction that provides the size of the new-array.
IRInstruction* find_new_array_size_insn(live_range::UseDefChains* use_defs,
                                        IRInstruction* new_array_insn) {
  auto& const_defs = use_defs->at(live_range::Use{new_array_insn, 0});
  always_assert_log(const_defs.size() == 1, "Should be 1 def for insn %s",
                    SHOW(new_array_insn));
  auto* const_def = *const_defs.begin();
  always_assert_log(const_def->opcode() == OPCODE_CONST,
                    "Array size expected to be a const instruction. Got %s",
                    SHOW(const_def));
  return const_def;
}

IRInstruction* find_fill_array_data_use(
    const UnorderedSet<live_range::Use>& uses) {
  std::vector<IRInstruction*> matched;
  for (const auto& u : UnorderedIterable(uses)) {
    if (u.insn->opcode() == OPCODE_FILL_ARRAY_DATA) {
      matched.push_back(u.insn);
    }
  }
  always_assert_log(matched.size() == 1,
                    "Did not find exactly one expected use");
  return matched.front();
}

// Build cfg, set rstate as we expect to simulate outlined methods
void prepare_methods_for_test(DexClasses& classes) {
  for (auto* cls : classes) {
    for (auto* m : cls->get_all_methods()) {
      auto* code = m->get_code();
      if (code != nullptr) {
        code->build_cfg();
      }
    }
  }
}
} // namespace

// Descriptors of the R classes exercised by the tests below. `constexpr` makes
// the pointers themselves immutable and gives them internal linkage, so they
// cannot be reassigned or collide with same-named globals elsewhere.
constexpr const char* base_r_class_name = "Lcom/redextest/R;";
constexpr const char* styleable_r_class_name = "Lcom/redextest/R$styleable;";
constexpr const char* another_styleable_r_class_name =
    "Lcom/redextest/R$styleable2;";
constexpr const char* styleable_sgets_r_class_name =
    "Lcom/redextest/R$styleable_sgets;";

class RClassTest : public RedexIntegrationTest {
 public:
  ResourceConfig global_resources_config;
  ResourceConfig instrumented_global_resources_config;
  DexClass* base_r_class;

  RClassTest() {
    prepare_methods_for_test(*classes);
    // The outer R class is assumed to have been customized to store extra junk,
    // like in our buck builds of applications.
    global_resources_config.customized_r_classes.emplace(base_r_class_name);
    instrumented_global_resources_config.customized_r_classes.emplace(
        base_r_class_name);
    instrumented_global_resources_config.cleanup_r_class_rewriting = false;
    base_r_class = get_r_class(*classes, base_r_class_name);
  }
};

// Checks that extract_resource_ids_from_static_arrays only reports IDs stored
// in the arrays of the requested fields.
TEST_F(RClassTest, extractStaticArrayValues) {
  resources::RClassReader r_class_reader(global_resources_config);
  // Basic check on returning only IDs related to given fields: no fields
  // requested means no IDs.
  {
    UnorderedSet<uint32_t> values;
    r_class_reader.extract_resource_ids_from_static_arrays({base_r_class}, {},
                                                           &values);
    EXPECT_EQ(values.size(), 0u);
  }
  {
    DexFieldRef* ref = DexField::get_field("Lcom/redextest/R;.one:[I");
    // Fail the test instead of crashing on ref->as_def() if the fixture changes.
    ASSERT_NE(ref, nullptr);
    UnorderedSet<uint32_t> values;
    r_class_reader.extract_resource_ids_from_static_arrays(
        {base_r_class}, {ref->as_def()}, &values);
    EXPECT_THAT(unordered_unsafe_unwrap(values),
                testing::UnorderedElementsAre(0x7f010000, 0x7f010001,
                                              0x7f010002, 0x7f010003));
  }
  {
    DexFieldRef* ref_a = DexField::get_field("Lcom/redextest/R;.two:[I");
    DexFieldRef* ref_b = DexField::get_field("Lcom/redextest/R;.three:[I");
    ASSERT_NE(ref_a, nullptr);
    ASSERT_NE(ref_b, nullptr);
    UnorderedSet<uint32_t> values;
    r_class_reader.extract_resource_ids_from_static_arrays(
        {base_r_class}, {ref_a->as_def(), ref_b->as_def()}, &values);
    EXPECT_THAT(unordered_unsafe_unwrap(values),
                testing::UnorderedElementsAre(0x7f020000, 0x7f020001,
                                              0x7f020002, 0x7f020003,
                                              0x7f030000, 0x7f030001));
  }

  DexClass* another_styleable_r_class =
      get_r_class(*classes, another_styleable_r_class_name);
  {
    DexFieldRef* ref =
        DexField::get_field("Lcom/redextest/R$styleable2;.five:[I");
    ASSERT_NE(ref, nullptr);
    UnorderedSet<uint32_t> values;
    r_class_reader.extract_resource_ids_from_static_arrays(
        {another_styleable_r_class}, {ref->as_def()}, &values);
    EXPECT_THAT(unordered_unsafe_unwrap(values),
                testing::UnorderedElementsAre(0x7f050000,
                                              0x7f050001,
                                              0x7f050002,
                                              0x7f050003,
                                              0x7f050004,
                                              0x7f050005,
                                              0x7f050006,
                                              0x7f050007,
                                              0x7f050008,
                                              0x7f050009));
  }

  // Searches the whole scope rather than a single class.
  {
    DexFieldRef* ref =
        DexField::get_field("Lcom/redextest/R$styleable_sgets;.seven:[I");
    ASSERT_NE(ref, nullptr);
    UnorderedSet<uint32_t> values;
    r_class_reader.extract_resource_ids_from_static_arrays(
        *classes, {ref->as_def()}, &values);
    EXPECT_THAT(unordered_unsafe_unwrap(values),
                testing::UnorderedElementsAre(0x7f040000, 0x7f040001))
        << "seven is incorrect";
  }
}

// Checks that analyze_clinit recovers the array contents assigned to each
// static array field of the outer R class and of R$styleable.
TEST_F(RClassTest, analyzeStaticInitializers) {
  dump_code_verbose(base_r_class->get_clinit()->get_code());

  DexClass* styleable_class = get_r_class(*classes, styleable_r_class_name);
  dump_code_verbose(styleable_class->get_clinit()->get_code());

  resources::RClassReader r_class_reader(global_resources_config);

  auto field_values = r_class_reader.analyze_clinit(base_r_class, {});
  dump_field_values(field_values);

  // Fatal check: the block below dereferences four consecutive entries, which
  // would read past end() if fewer were found.
  ASSERT_EQ(field_values.size(), 4u);
  {
    auto it = field_values.begin();
    EXPECT_STREQ(SHOW(it->first), "Lcom/redextest/R;.one:[I");
    EXPECT_THAT(
        it->second,
        testing::ElementsAre(0x7f010000, 0x7f010001, 0x7f010002, 0x7f010003));
    it++;
    EXPECT_STREQ(SHOW(it->first), "Lcom/redextest/R;.six:[I");
    EXPECT_THAT(it->second,
                testing::ElementsAre(0x7f060000,
                                     0x7f060001,
                                     0x7f060002,
                                     0x7f060003,
                                     0x7f060004,
                                     0x7f060005,
                                     0x7f060006,
                                     0x7f060007,
                                     0x7f060008,
                                     0x7f060009));
    it++;
    EXPECT_STREQ(SHOW(it->first), "Lcom/redextest/R;.three:[I");
    EXPECT_THAT(it->second, testing::ElementsAre(0x7f030000, 0x7f030001));
    it++;
    EXPECT_STREQ(SHOW(it->first), "Lcom/redextest/R;.two:[I");
    EXPECT_THAT(
        it->second,
        testing::ElementsAre(0x7f020000, 0x7f020001, 0x7f020002, 0x7f020003));
  }

  field_values.clear();

  auto styleable_field_values =
      r_class_reader.analyze_clinit(styleable_class, {});
  dump_field_values(styleable_field_values);

  // Fatal check: begin() is dereferenced below.
  ASSERT_EQ(styleable_field_values.size(), 1u);
  EXPECT_STREQ(SHOW(styleable_field_values.begin()->first),
               "Lcom/redextest/R$styleable;.four:[I");
  {
    const auto& array_values = styleable_field_values.begin()->second;
    EXPECT_EQ(array_values.size(), 2u);
    EXPECT_THAT(array_values, testing::ElementsAre(0x7f040000, 0x7f040001));
  }
}

// Remaps/deletes IDs in R class arrays and verifies the rewritten <clinit>
// allocates and fills arrays of the expected sizes.
TEST_F(RClassTest, remapResourceClassArrays) {
  std::cout << "BASELINE R <clinit>:" << "\n";
  auto* clinit = base_r_class->get_clinit();
  auto* code = clinit->get_code();
  dump_code_verbose(code);

  // A typical styleable inner class, which has different conventions and is
  // indexed directly into. Deletion should instead insert zeros.
  DexClass* styleable_class = get_r_class(*classes, styleable_r_class_name);
  std::cout << "\n"
            << "BASELINE R$styleable <clinit>:" << "\n";
  auto* styleable_clinit = styleable_class->get_clinit();
  auto* styleable_code = styleable_clinit->get_code();
  dump_code_verbose(styleable_code);

  // Just another class with arrays, so tests can be written against other valid
  // opcode sequences for filling arrays statically.
  DexClass* another_styleable_r_class =
      get_r_class(*classes, another_styleable_r_class_name);
  std::cout << "\n"
            << "BASELINE R$styleable2 <clinit>:" << "\n";
  auto* another_clinit = another_styleable_r_class->get_clinit();
  auto* another_code = another_clinit->get_code();
  dump_code_verbose(another_code);

  // IDs absent from this map are treated as deleted.
  std::map<uint32_t, uint32_t> old_to_remapped_ids;
  // Remap all 4 items in the first array.
  old_to_remapped_ids.emplace(0x7f010000, 0x7f010010);
  old_to_remapped_ids.emplace(0x7f010001, 0x7f010011);
  old_to_remapped_ids.emplace(0x7f010002, 0x7f010012);
  old_to_remapped_ids.emplace(0x7f010003, 0x7f010013);
  // Keep the first two items from the second array, and delete the last 2.
  old_to_remapped_ids.emplace(0x7f020000, 0x7f020000);
  old_to_remapped_ids.emplace(0x7f020001, 0x7f020001);
  // Keep the first item from the third array, delete the last.
  old_to_remapped_ids.emplace(0x7f030000, 0x7f030000);
  // For styleable, delete first and keep last
  old_to_remapped_ids.emplace(0x7f040001, 0x7f040001);
  // As above, but with a larger array that would be generated with a different
  // sequence of opcodes. Delete the first ID 0x7f050000 and keep the rest.
  for (uint32_t i = 0x7f050001; i <= 0x7f050009; i++) {
    old_to_remapped_ids.emplace(i, i);
  }
  // Non-styleable big array deletion. Delete the first and keep the rest.
  for (uint32_t i = 0x7f060001; i <= 0x7f060009; i++) {
    old_to_remapped_ids.emplace(i, i);
  }

  resources::RClassWriter r_class_writer(global_resources_config);
  r_class_writer.remap_resource_class_arrays(stores, old_to_remapped_ids);

  // Take the string representation of the field, and its expected array size to
  // ensure that a sensible SPUT was emitted.
  auto verify_expected_sizes =
      [&](cfg::ControlFlowGraph& cfg,
          const std::unordered_map<std::string, uint32_t>& expected_sizes,
          const std::function<void(const std::string& field_name,
                                   const std::vector<uint32_t>& payload)>&
              payload_callback) {
        std::unordered_set<std::string> seen;
        live_range::MoveAwareChains move_aware_chains(cfg);
        auto use_defs = move_aware_chains.get_use_def_chains();
        auto def_uses = move_aware_chains.get_def_use_chains();
        for (const auto& mie : cfg::InstructionIterable(cfg)) {
          auto* insn = mie.insn;
          if (insn->opcode() != OPCODE_SPUT_OBJECT) {
            continue;
          }
          auto field_name = show(insn->get_field());
          // Report an unexpected field as a test failure instead of letting
          // unordered_map::at throw out of the test.
          auto size_it = expected_sizes.find(field_name);
          ASSERT_TRUE(size_it != expected_sizes.end())
              << "Unexpected sput-object to " << field_name;
          auto expected_size = size_it->second;
          auto& array_defs = use_defs.at(live_range::Use{insn, 0});
          // Fatal: begin() is dereferenced right below.
          ASSERT_EQ(array_defs.size(), 1u);
          // Make sure the array was created with the proper constant size.
          auto* array_def = *array_defs.begin();
          auto* const_def = find_new_array_size_insn(&use_defs, array_def);
          EXPECT_EQ(const_def->get_literal(), expected_size);
          // Make sure the payload that fills the array is the proper size too
          // and optionally dispatch a callback for further asserts.
          const auto& array_uses = def_uses.at(array_def);
          auto* fill_array_data_insn = find_fill_array_data_use(array_uses);
          auto* op_data = fill_array_data_insn->get_data();
          auto payload = get_fill_array_data_payload<uint32_t>(op_data);
          EXPECT_EQ(payload.size(), expected_size);
          payload_callback(field_name, payload);
          seen.emplace(field_name);
        }
        EXPECT_EQ(seen.size(), expected_sizes.size())
            << "Did not find expected number of new array opcodes!";
      };

  std::cout << "\n"
            << "MODIFIED R <clinit>:" << "\n";
  dump_code_verbose(code);

  auto no_further_checks = [](const std::string& /*field_name*/,
                              const std::vector<uint32_t>& /*payload*/) {};
  verify_expected_sizes(code->cfg(),
                        {{"Lcom/redextest/R;.one:[I", 4},
                         {"Lcom/redextest/R;.two:[I", 2},
                         {"Lcom/redextest/R;.three:[I", 1},
                         {"Lcom/redextest/R;.six:[I", 9}},
                        no_further_checks);

  std::cout << "\n"
            << "MODIFIED R$styleable <clinit>:" << "\n";
  dump_code_verbose(styleable_code);
  verify_expected_sizes(styleable_code->cfg(),
                        {{"Lcom/redextest/R$styleable;.four:[I", 2}},
                        no_further_checks);

  std::cout << "\n"
            << "MODIFIED R$styleable2 <clinit>:" << "\n";
  dump_code_verbose(another_code);
  auto callback = [](const std::string& /*field_name*/,
                     const std::vector<uint32_t>& payload) {
    // Fatal: the two elements are indexed below.
    ASSERT_GE(payload.size(), 2u);
    EXPECT_EQ(payload[0], 0u) << "First element should be zeroed out";
    EXPECT_EQ(payload[1], 0x7f050001u)
        << "Second element should remain intact";
  };
  verify_expected_sizes(another_code->cfg(),
                        {{"Lcom/redextest/R$styleable2;.five:[I", 10}},
                        callback);
}

// Checks that an ID is remapped at most once: 0x7f040000 -> TARGET_ID must not
// be followed by TARGET_ID -> 0x7f099999.
TEST_F(RClassTest, noDoubleRemappingArrays) {
  // This setup ensures that the value does not get remapped twice.
  constexpr uint32_t TARGET_ID = 0x7f090000;
  std::map<uint32_t, uint32_t> old_to_remapped_ids{{0x7f040000, TARGET_ID},
                                                   {TARGET_ID, 0x7f099999}};
  resources::RClassWriter r_class_writer(global_resources_config);
  r_class_writer.remap_resource_class_arrays(stores, old_to_remapped_ids);

  DexClass* sget_r_class = get_r_class(*classes, styleable_sgets_r_class_name);
  auto* clinit = sget_r_class->get_clinit();
  auto* code = clinit->get_code();
  dump_code_verbose(code);
  bool found_array{false};
  for (const auto& mie : cfg::InstructionIterable(code->cfg())) {
    auto* insn = mie.insn;
    if (insn->opcode() == OPCODE_FILL_ARRAY_DATA) {
      found_array = true;
      auto* op_data = insn->get_data();
      auto payload = get_fill_array_data_payload<uint32_t>(op_data);
      // Fatal: both elements are indexed below.
      ASSERT_EQ(payload.size(), 2u) << "Should have two elements!";
      EXPECT_EQ(payload[0], TARGET_ID) << "Remapping was incorrect!";
      EXPECT_EQ(payload[1], 0u) << "Second values should be zeroed out!";
    }
  }
  EXPECT_TRUE(found_array);
}

// Checks that remapping tolerates instrumentation calls in an R class <clinit>
// when cleanup_r_class_rewriting is disabled.
TEST_F(RClassTest, bePermissiveWhenInInstrumentationMode) {
  auto* instrumented_r_cls = assembler::class_from_string(R"(
    (class (public) "Lcom/redextest/R$styleable_instr;"
      (field (public static final) "Lcom/redextest/R$styleable_instr;.eight:[I")
      (method (static constructor) "Lcom/redextest/R$styleable_instr;.<clinit>:()V"
        (
          (const v0 999)
          (invoke-static (v0) "Lcom/redex/Instrumentation;.onMethodBegin:(I)V")
          (const v0 2)
          (new-array v0 "[I") ; create an array of length 2
          (move-result-pseudo-object v1)
          (fill-array-data v1 #4 (7f080000 7f080001))
          (sput-object v1 "Lcom/redextest/R$styleable_instr;.eight:[I")
          (return-void)
        )
      )
      (method (constructor) "Lcom/redextest/R$styleable_instr;.<init>:()V"
        (
          (load-param-object v0)
          (const v1 1000)
          (invoke-static (v1) "Lcom/redex/Instrumentation;.onMethodBegin:(I)V")
          (invoke-direct (v0) "Ljava/lang/Object;.<init>:()V")
          (return-void)
        )
      )
    )
  )");

  // Build a completely separate scope and store holding only the specially
  // instrumented class. This lets the example and its config options be tested
  // in isolation from the fixture's classes and stores.
  Scope test_scope{instrumented_r_cls};
  prepare_methods_for_test(test_scope);

  DexMetadata temp_metadata;
  temp_metadata.set_id("classes");
  DexStore temp_store(temp_metadata);
  temp_store.add_classes(test_scope);
  std::vector<DexStore> temp_stores{temp_store};

  auto* code = instrumented_r_cls->get_clinit()->get_code();
  std::cout << "Instrumented R clinit:" << "\n";
  dump_code_verbose(code);

  std::map<uint32_t, uint32_t> old_to_remapped_ids{{0x7f080000, 0x7f090000},
                                                   {0x7f080001, 0x7f090001}};

  resources::RClassWriter r_class_writer(instrumented_global_resources_config);
  r_class_writer.remap_resource_class_arrays(temp_stores, old_to_remapped_ids);

  std::cout << "Remapped Instrumented R clinit:" << "\n";
  dump_code_verbose(code);
  // With cleanup disabled the test expects two fill-array-data instructions
  // (see the final check). The second one is expected to hold the remapped IDs.
  size_t found_count{0};
  for (const auto& mie : cfg::InstructionIterable(code->cfg())) {
    auto* insn = mie.insn;
    if (insn->opcode() == OPCODE_FILL_ARRAY_DATA) {
      found_count++;
      if (found_count == 2) {
        auto* op_data = insn->get_data();
        auto payload = get_fill_array_data_payload<uint32_t>(op_data);
        // Fatal: both elements are indexed below.
        ASSERT_EQ(payload.size(), 2u) << "Should have two elements!";
        EXPECT_EQ(payload[0], 0x7f090000u) << "Remapping was incorrect!";
        EXPECT_EQ(payload[1], 0x7f090001u) << "Remapping was incorrect!";
      }
    }
  }
  EXPECT_EQ(found_count, 2u) << "Cleanup should have been turned off!";
}
